#!/usr/bin/env python3
"""
Single, reusable feature extractor for Python script corpora (28-script study)
-----------------------------------------------------------------------------
Extracts:
 A) Dependency matrix (binary 'dep:<lib>')
 B) Behavioral features with normalization:
    - Fully-qualified dotted calls 'call:<fq>'
    - Terminal methods 'term:<method>'
    - File/Network I/O tokens 'io:<...>'
    - Constructors 'ctor:<Class>'
 C) Simple intra-script call-graph metrics 'g:<metric>'

Defaults: min_df=2, TF-IDF (smooth idf, sublinear tf) + L2 normalization,
CSV separator=';' and decimal=',', and README emission.

USAGE
-----
# Input can be a folder or a ZIP that contains .py files (recursively).
python extract_features.py -i /path/to/scripts_or_zip -o /path/to/out

Optional flags:
  --min-df 3                 # raise rarity threshold
  --normalize l2counts       # use L2-normalized counts instead of TF-IDF
  --export-counts            # also export raw counts matrices
  --no-graph                 # skip graph metrics
  --sep ';' --decimal ','    # CSV formatting
  --exclude-dirs venv .git   # skip these directories by name
  --readme/--no-readme       # toggle README generation

Requirements: Python 3.9+, pandas, numpy
"""

from __future__ import annotations
import os, sys, io, re, ast, math, zipfile, argparse, shutil
from collections import Counter, defaultdict, deque
from typing import Dict, List, Set, Tuple
import pandas as pd
import numpy as np

VERSION = "1.0.0"

# -----------------------------
# Helpers
# -----------------------------
# Names that are too ubiquitous to carry signal. extract_calls_and_terms()
# drops the 'term:<name>' token whenever the *terminal* name of a call is in
# this set, so a method such as `obj.open()` or `df.sum()` is filtered as well,
# not only the bare builtin.
BUILTIN_EXCLUDE = {
    # Very common builtins to exclude from behavioral (except open kept in IO)
    "print", "len", "range", "enumerate", "list", "dict", "set", "tuple",
    "int", "float", "str", "bool", "sum", "min", "max", "zip", "map", "filter",
    "any", "all", "sorted", "open"  # 'open' handled separately under IO
}

# Class names that extract_constructors() accepts as constructors in addition
# to anything matching CAMEL_CASE_RE.
# NOTE(review): every entry here already matches CAMEL_CASE_RE, so the
# whitelist currently adds nothing on top of the regex check. "Datetime",
# "Timedelta" and "DefaultDict" may have been meant to cover the lowercase
# stdlib spellings (datetime, timedelta, defaultdict) -- confirm intent.
CTOR_WHITELIST = {
    "BeautifulSoup", "Session", "Request", "Path", "PurePath",
    "Chrome", "Firefox", "WebDriverWait", "By",
    "Datetime", "Timedelta", "ExcelWriter", "Counter", "DefaultDict"
}

# Root modules whose calls are reported as 'io:<fq>' tokens.
# extract_io_tokens() compares the first dotted component of each resolved
# call against this set, so *every* call rooted in e.g. "pandas" or "os" is
# counted as I/O, not just the reading/writing ones.
IO_ROOTS = {
    "os", "pathlib", "shutil", "glob", "requests", "urllib", "csv", "zipfile", "tarfile", "pandas"
}
IO_KEEP_NAMES = {"open"}  # builtins we keep under IO

import re as _re

# An identifier "looks like a class" when it starts with an ASCII capital and
# continues with letters, digits or underscores. All-caps names such as "URL"
# match too.
CAMEL_CASE_RE = _re.compile(r'^[A-Z][A-Za-z0-9_]*$')


def is_camel_case(name: str) -> bool:
    """Return True when *name* starts with an uppercase ASCII letter and
    otherwise consists only of letters, digits and underscores."""
    matched = CAMEL_CASE_RE.match(name)
    return matched is not None

def build_alias_map(tree: ast.AST) -> Dict[str, str]:
    """Map every name bound by an import statement to what it refers to.

    The result is used to expand the first component of a dotted call, e.g.
    ``rq.get`` -> ``requests.get`` after ``import requests as rq``.

    Args:
        tree: Parsed module (or any AST subtree) to scan for imports.

    Returns:
        ``{local_name: qualified_name}``:

        * ``import a.b``         -> ``{"a": "a"}`` (Python binds only the
          top-level package, so ``a.b.f()`` must stay ``a.b.f``)
        * ``import a.b as x``    -> ``{"x": "a.b"}``
        * ``from m import n``    -> ``{"n": "m.n"}``
        * ``from m import n as x`` -> ``{"x": "m.n"}``
        * ``from . import n``    -> ``{"n": "n"}`` (no module part available)

        ``from m import *`` binds no statically known name and is skipped.
    """
    alias_map: Dict[str, str] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    # `import a.b as x` binds x to the submodule a.b itself.
                    alias_map[alias.asname] = alias.name
                else:
                    # `import a.b` binds only `a`; mapping it to "a.b" would
                    # turn the call `a.b.f()` into the bogus "a.b.b.f".
                    root = alias.name.split('.')[0]
                    alias_map[root] = root
        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            for alias in node.names:
                if alias.name == "*":
                    # Star imports do not bind a name called "*".
                    continue
                local = alias.asname or alias.name
                alias_map[local] = f"{module}.{alias.name}" if module else alias.name
    return alias_map

def extract_imports(tree: ast.AST) -> Set[str]:
    """Return the set of top-level package names imported anywhere in *tree*.

    ``import a.b`` and ``from a.b import c`` both contribute ``"a"``.
    ``from . import x`` has no module part and contributes nothing.
    """
    roots: Set[str] = set()
    for stmt in ast.walk(tree):
        if isinstance(stmt, ast.ImportFrom):
            if stmt.module:
                roots.add(stmt.module.partition('.')[0])
        elif isinstance(stmt, ast.Import):
            roots.update(item.name.partition('.')[0] for item in stmt.names)
    return roots

def get_attr_chain(node: ast.AST) -> List[str]:
    """Flatten an attribute-access expression into its dotted components.

    Walks ``a.b.c`` from the outermost attribute inwards and returns the parts
    in source order (``["a", "b", "c"]``). The innermost base is rendered as:

    * its identifier, for a plain name;
    * ``"<call>"`` for a call result (``f().g``);
    * ``"<subscript>"`` for an indexing expression (``x[0].g``);
    * ``"<expr>"`` for any other base (literal, parenthesised expression, ...)
      when at least one attribute was seen.

    The ``"<expr>"`` placeholder keeps the first element of the chain a *base*
    rather than a method name. Without it ``", ".join(xs)`` would come back as
    ``["join"]`` and callers that resolve the first element through the import
    alias map could mistake it for an imported ``join``.

    Args:
        node: Typically the ``func`` of an ``ast.Call``.

    Returns:
        The chain components, or ``[]`` when *node* is neither an attribute
        access nor one of the recognised base kinds (e.g. a bare lambda).
    """
    parts: List[str] = []
    cur = node
    while isinstance(cur, ast.Attribute):
        parts.append(cur.attr)
        cur = cur.value
    if isinstance(cur, ast.Name):
        parts.append(cur.id)
    elif isinstance(cur, ast.Call):
        parts.append("<call>")
    elif isinstance(cur, ast.Subscript):
        parts.append("<subscript>")
    elif parts:
        # Attribute access on some other expression: mark the base explicitly
        # so the method name is not promoted to chain[0].
        parts.append("<expr>")
    parts.reverse()
    return parts

def resolve_root(name: str, alias_map: Dict[str, str]) -> str:
    """Return the import target recorded for *name*, or *name* itself when the
    alias map has no entry for it."""
    if name in alias_map:
        return alias_map[name]
    return name

def extract_calls_and_terms(tree: ast.AST, alias_map: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Tokenise every call expression found in *tree*.

    For each call two kinds of token may be produced:

    * ``term:<name>`` -- the last component of the callee chain, unless that
      name is listed in ``BUILTIN_EXCLUDE``;
    * ``call:<dotted>`` -- the full chain with its first component expanded
      through *alias_map*.

    Returns:
        ``(call_tokens, term_tokens)``, each in AST walk order with
        duplicates preserved.
    """
    call_tokens: List[str] = []
    term_tokens: List[str] = []

    call_nodes = (n for n in ast.walk(tree) if isinstance(n, ast.Call))
    for call in call_nodes:
        chain = get_attr_chain(call.func)
        if not chain:
            # Callee is not a name/attribute expression; nothing to record.
            continue

        head, *rest = chain
        last = chain[-1]
        if last and last not in BUILTIN_EXCLUDE:
            term_tokens.append(f"term:{last}")

        # Chain components are always strings, so the head can be resolved
        # directly; an empty resolution suppresses the call token.
        target = resolve_root(head, alias_map)
        if target:
            call_tokens.append("call:" + ".".join([target, *rest]))

    return call_tokens, term_tokens

def extract_io_tokens(fq_calls: List[str], terms: List[str]) -> List[str]:
    """Derive ``io:<...>`` tokens from already-extracted call/term tokens.

    Two sources contribute:

    * builtins listed in ``IO_KEEP_NAMES`` (currently ``open``) ->
      ``io:<name>``;
    * any ``call:<dotted>`` whose first component is in ``IO_ROOTS`` ->
      ``io:<dotted>``.

    The builtin is normally detected from the bare ``call:open`` token. It
    cannot be detected from *terms*, because ``extract_calls_and_terms`` filters
    ``open`` out of the term bag via ``BUILTIN_EXCLUDE``, which is why
    ``io:open`` was previously never produced. If a caller does supply
    ``term:<name>`` tokens for those builtins they are honoured as before and
    the bare calls are then not counted a second time.

    Args:
        fq_calls: ``call:<dotted>`` tokens.
        terms: ``term:<name>`` tokens.

    Returns:
        List of ``io:`` tokens, duplicates preserved.
    """
    io_tokens: List[str] = []

    # Legacy path: explicit term tokens for the kept builtins.
    for t in terms:
        if t.startswith("term:") and t[len("term:"):] in IO_KEEP_NAMES:
            io_tokens.append(f"io:{t[len('term:'):]}")
    builtin_seen_in_terms = bool(io_tokens)

    for c in fq_calls:
        body = c.split("call:", 1)[-1]
        if body in IO_KEEP_NAMES:
            # A bare, unaliased call such as `open(...)`.
            if not builtin_seen_in_terms:
                io_tokens.append(f"io:{body}")
        elif body.split(".", 1)[0] in IO_ROOTS:
            io_tokens.append(f"io:{body}")
    return io_tokens

def extract_constructors(tree: ast.AST, alias_map: Dict[str, str]) -> List[str]:
    """Emit a ``ctor:<Class>`` token for every call that looks like an
    instantiation.

    A call qualifies when

    1. the last component of the callee is CamelCase or whitelisted, in which
       case that component is reported; otherwise
    2. the alias-resolved first component ends in a CamelCase name (e.g. a
       class imported under a lowercase alias), in which case that resolved
       name is reported.
    """
    found: List[str] = []
    for call in ast.walk(tree):
        if not isinstance(call, ast.Call):
            continue
        chain = get_attr_chain(call.func)
        if not chain:
            continue

        last = chain[-1]
        if last and (is_camel_case(last) or last in CTOR_WHITELIST):
            found.append(f"ctor:{last}")
            continue

        # Fall back to what the root name was imported as.
        resolved = resolve_root(chain[0], alias_map)
        if not resolved:
            continue
        leaf = resolved.split(".")[-1]
        if leaf and is_camel_case(leaf):
            found.append(f"ctor:{leaf}")
    return found

def extract_call_graph_stats(tree: ast.AST) -> Dict[str, float]:
    """Compute simple metrics of the name-based, intra-script call graph.

    Nodes are the distinct names of all functions/methods defined anywhere in
    *tree* (nested ones included). A directed edge ``caller -> callee`` is
    added when the body of ``caller`` contains a call whose callee is a bare
    name equal to one of those node names (attribute calls such as
    ``self.f()`` are never treated as local).

    Each call is attributed to its innermost enclosing function only: when
    scanning a function, nested ``def``s are not entered, because they are
    scanned in their own right. All definitions sharing a name contribute to
    the same node.

    Args:
        tree: Parsed module.

    Returns:
        Dict of floats keyed by:
        ``g:n_nodes``, ``g:n_edges``, ``g:avg_outdeg`` (edges / nodes),
        ``g:density`` (edges / (n * (n - 1)), 0 when n < 2; self-loops are
        counted as edges), ``g:pct_external_calls`` (share of calls made
        inside functions that are not local calls), ``g:max_outdeg``,
        ``g:max_indeg`` and ``g:component_count`` (weakly connected
        components).
    """
    func_defs = [
        n for n in ast.walk(tree)
        if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    local_names = {fn.name for fn in func_defs}
    edges: Set[Tuple[str, str]] = set()
    total_calls = 0
    local_calls = 0

    class CallVisitor(ast.NodeVisitor):
        """Counts the calls that belong directly to one function definition."""

        def __init__(self, root: ast.AST):
            self.root = root

        def visit_FunctionDef(self, node):
            # Only descend into the function being analysed; nested defs get
            # their own visitor, so entering them here would double count.
            if node is self.root:
                self.generic_visit(node)

        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Call(self, node: ast.Call):
            nonlocal total_calls, local_calls
            total_calls += 1
            if isinstance(node.func, ast.Name) and node.func.id in local_names:
                edges.add((self.root.name, node.func.id))
                local_calls += 1
            self.generic_visit(node)

    for fn in func_defs:
        CallVisitor(fn).visit(fn)

    n_nodes = len(local_names)
    n_edges = len(edges)
    avg_outdeg = n_edges / n_nodes if n_nodes > 0 else 0.0
    density = (n_edges / (n_nodes * (n_nodes - 1))) if n_nodes > 1 else 0.0
    pct_external_calls = (1.0 - (local_calls / total_calls)) if total_calls > 0 else 0.0

    outdeg = Counter(a for a, _ in edges)
    indeg = Counter(b for _, b in edges)
    max_outdeg = max(outdeg.values()) if outdeg else 0
    max_indeg = max(indeg.values()) if indeg else 0

    # Weakly connected components: ignore edge direction and flood-fill.
    undirected = defaultdict(set)
    for a, b in edges:
        undirected[a].add(b)
        undirected[b].add(a)
    for n in local_names:
        undirected.setdefault(n, set())

    visited = set()
    comp_count = 0
    for n in local_names:
        if n in visited:
            continue
        comp_count += 1
        dq = deque([n])
        visited.add(n)
        while dq:
            cur = dq.popleft()
            for nei in undirected[cur]:
                if nei not in visited:
                    visited.add(nei)
                    dq.append(nei)

    return {
        "g:n_nodes": float(n_nodes),
        "g:n_edges": float(n_edges),
        "g:avg_outdeg": float(avg_outdeg),
        "g:density": float(density),
        "g:pct_external_calls": float(pct_external_calls),
        "g:max_outdeg": float(max_outdeg),
        "g:max_indeg": float(max_indeg),
        "g:component_count": float(comp_count),
    }

def build_binary_dep_matrix(dep_records: List[Tuple[str, Set[str]]]) -> pd.DataFrame:
    """Build the script x module presence matrix.

    Rows follow the order of *dep_records* and are labelled by script name;
    columns are ``dep:<module>`` for every module seen in any record, sorted
    alphabetically. A cell is 1 when the script imports the module, else 0.
    """
    universe: Set[str] = set()
    for _, mods in dep_records:
        universe.update(mods)
    ordered = sorted(universe)

    scripts = [script for script, _ in dep_records]
    rows = [[int(m in mods) for m in ordered] for _, mods in dep_records]
    return pd.DataFrame(rows, index=scripts, columns=[f"dep:{m}" for m in ordered])

def vectorize_tokens_min_df(
    docs: List[Tuple[str, List[str]]],
    min_df:int=2,
    normalize:str="tfidf",   # 'tfidf' or 'l2counts'
    return_counts:bool=False
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Turn per-script token lists into a row-normalised feature matrix.

    Args:
        docs: ``(script_name, tokens)`` pairs; one output row per pair, in
            the given order.
        min_df: A token is kept only if it occurs in at least this many
            documents.
        normalize: ``"l2counts"`` L2-normalises the raw counts; any other
            value selects TF-IDF with sublinear tf (``1 + ln(count)``),
            smooth idf (``ln((1 + N) / (1 + df)) + 1``) and L2 row
            normalisation.
        return_counts: When true the first returned frame holds the raw
            counts; otherwise it is an empty frame carrying only the index.

    Returns:
        ``(counts_df, norm_df)``. Columns are the kept tokens in sorted
        order. All-zero rows stay all-zero (their norm is treated as 1). If
        no token reaches *min_df* both frames are column-less.
    """
    dfreq = Counter()
    for _, tokens in docs:
        dfreq.update(set(tokens))
    vocab = sorted(t for t, df in dfreq.items() if df >= min_df)
    idx = [s for s, _ in docs]

    if not vocab:
        return pd.DataFrame(index=idx), pd.DataFrame(index=idx)

    tok_index = {t: i for i, t in enumerate(vocab)}
    mat_counts = np.zeros((len(docs), len(vocab)), dtype=float)
    for r, (_, tokens) in enumerate(docs):
        for t, v in Counter(tokens).items():
            j = tok_index.get(t)
            if j is not None:
                mat_counts[r, j] = float(v)

    if normalize == "l2counts":
        weighted = mat_counts
    else:
        n_docs = len(docs)
        df_arr = np.array([dfreq[t] for t in vocab], dtype=float)
        idf = np.log((1.0 + n_docs) / (1.0 + df_arr)) + 1.0
        # Take the log only where a count exists. np.where(c > 0, log(c), 0)
        # would still evaluate log(0) for every empty cell and emit a
        # "divide by zero" RuntimeWarning.
        tf = np.zeros_like(mat_counts)
        present = mat_counts > 0
        tf[present] = 1.0 + np.log(mat_counts[present])
        weighted = tf * idf.reshape(1, -1)

    norms = np.linalg.norm(weighted, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1.0, norms)  # leave empty rows as zeros
    norm_df = pd.DataFrame(weighted / norms, index=idx, columns=vocab)

    if return_counts:
        counts_df = pd.DataFrame(mat_counts, index=idx, columns=vocab)
    else:
        counts_df = pd.DataFrame(index=idx)
    return counts_df, norm_df

def safe_read_text(path: str) -> str:
    """Read a source file as text, tolerating non-UTF-8 content.

    The file is decoded as UTF-8 first; a leading byte-order mark (written by
    some Windows editors) is stripped so that it does not end up as a stray
    U+FEFF character at the start of the source handed to the parser. If the
    bytes are not valid UTF-8 the file is re-read as Latin-1, which maps every
    byte to a character and therefore cannot fail.

    Args:
        path: File to read.

    Returns:
        The decoded file contents.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    try:
        # "utf-8-sig" is plain UTF-8 that additionally drops a leading BOM.
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, "r", encoding="latin-1") as f:
            return f.read()

def unzip_if_needed(input_path: str, work_dir: str) -> str:
    """Return a directory that holds the input sources.

    A directory input is returned unchanged. A ZIP input is extracted into
    *work_dir* (which is wiped first if it already exists) and *work_dir* is
    returned.

    Raises:
        ValueError: If *input_path* is neither a directory nor a ZIP archive.
    """
    if os.path.isdir(input_path):
        return input_path

    if not zipfile.is_zipfile(input_path):
        raise ValueError(f"Input path '{input_path}' is neither a directory nor a ZIP.")

    # Start from a clean extraction directory so stale files never leak in.
    if os.path.exists(work_dir):
        shutil.rmtree(work_dir)
    os.makedirs(work_dir, exist_ok=True)

    with zipfile.ZipFile(input_path, 'r') as archive:
        archive.extractall(work_dir)
    return work_dir

def find_py_files(root: str, exclude_dirs: List[str]) -> List[str]:
    """Recursively list the ``.py`` files under *root*, sorted by path.

    Directories whose *name* appears in *exclude_dirs* are pruned at any
    depth, together with everything below them.
    """
    skip = set(exclude_dirs or [])
    found: List[str] = []
    for here, subdirs, files in os.walk(root):
        # In-place assignment is what tells os.walk not to descend further.
        subdirs[:] = [name for name in subdirs if name not in skip]
        found.extend(
            os.path.join(here, name) for name in files if name.endswith(".py")
        )
    found.sort()
    return found

def save_csv(df: pd.DataFrame, path: str, sep: str, decimal: str):
    """Write *df* to *path* as UTF-8 CSV with rows sorted by index label.

    The caller's frame is left untouched; sorting works on a copy.
    """
    ordered = df.sort_index()
    ordered.to_csv(path, sep=sep, decimal=decimal, encoding="utf-8")

def _build_readme(args, src_root: str) -> str:
    """Render the README.txt text describing this run's settings and outputs."""
    # --normalize is restricted to 'tfidf' / 'l2counts', which are also the
    # file-name suffixes of the normalised matrices.
    suffix = args.normalize
    return f"""Feature export generated by extract_features.py v{VERSION}
--------------------------------------------------------------
Input      : {args.input}
Source root: {src_root}
Output dir : {args.out}

Settings
- min_df           : {args.min_df}
- normalize        : {args.normalize}
- export_counts    : {args.export_counts}
- compute_graph    : {not args.no_graph}
- CSV sep/decimal  : '{args.sep}' / '{args.decimal}'

Outputs
- dep_matrix.csv
- behavior_calls_{suffix}.csv
- behavior_terms_{suffix}.csv
- fileio_ops_{suffix}.csv
- constructors_{suffix}.csv
- graph_stats.csv             (if graph enabled)
- *_counts.csv                (if --export-counts)

Notes
- Dotted-call reconstruction uses alias resolution (e.g., 'rq.get' -> 'requests.get').
- 'term:<method>' bags are alias-agnostic and robust when modules can't be resolved.
- I/O tokens include builtin 'open' plus dotted calls rooted in: {sorted(IO_ROOTS)}
- Constructors are detected by CamelCase/whitelist and alias resolution.
- Graph metrics are a light, intra-script view; external calls are summarized via g:pct_external_calls.
"""


def _run_extraction(args, src_root: str) -> None:
    """Parse every script under *src_root* and write all feature files to args.out."""
    print(f"[extractor] Source root: {src_root}")
    py_files = find_py_files(src_root, args.exclude_dirs)
    if not py_files:
        print("[extractor] No .py files found. Exiting.", file=sys.stderr)
        sys.exit(1)
    print(f"[extractor] Found {len(py_files)} Python files.")

    dep_records: List[Tuple[str, Set[str]]] = []
    calls_docs: List[Tuple[str, List[str]]] = []
    terms_docs: List[Tuple[str, List[str]]] = []
    io_docs: List[Tuple[str, List[str]]] = []
    ctor_docs: List[Tuple[str, List[str]]] = []
    graph_rows: List[Dict[str, float]] = []

    for path in py_files:
        # Forward slashes keep row labels identical across platforms.
        rel = os.path.relpath(path, src_root).replace("\\", "/")
        src = safe_read_text(path)
        try:
            tree = ast.parse(src)
        except SyntaxError as e:
            print(f"[warn] SyntaxError in {rel}: {e}", file=sys.stderr)
            continue
        except ValueError as e:
            # Before Python 3.12 a source containing NUL bytes raises
            # ValueError rather than SyntaxError; skip such files too.
            print(f"[warn] ValueError in {rel}: {e}", file=sys.stderr)
            continue

        alias_map = build_alias_map(tree)
        imports = extract_imports(tree)
        fq_calls, terms = extract_calls_and_terms(tree, alias_map)
        io_tokens = extract_io_tokens(fq_calls, terms)
        ctors = extract_constructors(tree, alias_map)

        dep_records.append((rel, imports))
        calls_docs.append((rel, fq_calls))
        terms_docs.append((rel, terms))
        io_docs.append((rel, io_tokens))
        ctor_docs.append((rel, ctors))

        if not args.no_graph:
            g = {"script": rel}
            g.update(extract_call_graph_stats(tree))
            graph_rows.append(g)

    def vectorize(docs):
        return vectorize_tokens_min_df(
            docs, min_df=args.min_df, normalize=args.normalize,
            return_counts=args.export_counts,
        )

    dep_df = build_binary_dep_matrix(dep_records)
    # (file-name stem, (counts_df, norm_df)) in the order the files are written.
    families = [
        ("behavior_calls", vectorize(calls_docs)),
        ("behavior_terms", vectorize(terms_docs)),
        ("fileio_ops", vectorize(io_docs)),
        ("constructors", vectorize(ctor_docs)),
    ]

    def out_path(name: str) -> str:
        return os.path.join(args.out, name)

    save_csv(dep_df, out_path("dep_matrix.csv"), args.sep, args.decimal)
    for stem, (_, norm_df) in families:
        save_csv(norm_df, out_path(f"{stem}_{args.normalize}.csv"), args.sep, args.decimal)

    if not args.no_graph:
        if graph_rows:
            graph_df = pd.DataFrame(graph_rows).set_index("script").sort_index()
        else:
            # Every file failed to parse: set_index("script") on a column-less
            # frame would raise KeyError, so emit an empty table instead.
            graph_df = pd.DataFrame(index=pd.Index([], name="script"))
        save_csv(graph_df, out_path("graph_stats.csv"), args.sep, args.decimal)

    if args.export_counts:
        for stem, (counts_df, _) in families:
            save_csv(counts_df, out_path(f"{stem}_counts.csv"), args.sep, args.decimal)

    if args.readme:
        with open(out_path("README.txt"), "w", encoding="utf-8") as f:
            f.write(_build_readme(args, src_root))


def main():
    """Command-line entry point: parse arguments, extract features, write CSVs.

    A ZIP input is unpacked into ``<out>/_unzipped_tmp``; that directory is
    removed again when the run finishes, including on early exit or error.
    Exits with status 1 when the input contains no ``.py`` files.
    """
    ap = argparse.ArgumentParser(description="Single, reusable feature extractor for script corpora")
    ap.add_argument("-i", "--input", required=True, help="Folder OR ZIP containing .py files (recursively)")
    ap.add_argument("-o", "--out", required=True, help="Output folder for CSVs")
    ap.add_argument("--min-df", type=int, default=2, help="Minimum document frequency threshold (default: 2)")
    ap.add_argument("--normalize", choices=["tfidf","l2counts"], default="tfidf", help="Normalization method for behavioral features")
    ap.add_argument("--export-counts", action="store_true", help="Also export raw counts matrices")
    ap.add_argument("--no-graph", action="store_true", help="Skip computing graph stats")
    ap.add_argument("--sep", default=";", help="CSV separator (default: ;)")
    ap.add_argument("--decimal", default=",", help="CSV decimal mark (default: ,)")
    ap.add_argument("--exclude-dirs", nargs="*", default=["venv", ".venv", ".git", "__pycache__"], help="Directory names to exclude")
    ap.add_argument("--readme", dest="readme", action="store_true", help="Write README.txt (default on)")
    ap.add_argument("--no-readme", dest="readme", action="store_false", help="Disable README.txt")
    ap.set_defaults(readme=True)
    args = ap.parse_args()

    os.makedirs(args.out, exist_ok=True)
    work_dir = os.path.join(args.out, "_unzipped_tmp")
    src_root = unzip_if_needed(args.input, work_dir)
    try:
        _run_extraction(args, src_root)
    finally:
        # unzip_if_needed returns the input itself for directories, so a
        # different root means we extracted an archive and own work_dir.
        # Never delete a directory the user passed in as input.
        if src_root != args.input and os.path.isdir(work_dir):
            shutil.rmtree(work_dir, ignore_errors=True)

    print("[extractor] Done.")

if __name__ == "__main__":
    main()
